{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d78ad6e7",
   "metadata": {},
   "source": [
    "# SQL 데이터베이스와 상호작용하는 에이전트\n",
    "\n",
    "이 튜토리얼에서는 **SQL 데이터베이스에 대한 질문에 답할 수 있는 에이전트**를 단계별로 구축하는 방법을 소개합니다.  \n",
    "\n",
    "SQL 쿼리를 실행하는 에이전트의 흐름은 다음과 같습니다.\n",
    "\n",
    "1. **데이터베이스 스키마 파악**: 사용 가능한 테이블 목록을 가져옵니다.\n",
    "2. **관련 테이블 선택**: 질문과 연관된 테이블을 선택합니다.\n",
    "3. **DDL 조회**: 선택된 테이블의 스키마 정의(DDL)를 가져옵니다.\n",
    "4. **쿼리 생성**: 질문과 DDL 정보에 기반하여 SQL 쿼리를 작성합니다.\n",
    "5. **쿼리 점검**: LLM을 사용하여 일반적인 오류를 검토하고 쿼리를 개선합니다.\n",
    "6. **쿼리 실행 및 오류 처리**: 데이터베이스 엔진에 쿼리를 실행하고, 오류 발생 시 수정하여 성공적으로 쿼리를 수행합니다.\n",
    "7. **응답 생성**: 쿼리 결과를 기반으로 최종 답변을 제공합니다.\n",
    "\n",
    "![](./assets/langgraph-sql-agent.png)\n",
    "\n",
    "---\n",
    "\n",
    "**주요 내용**\n",
    "\n",
    "- **데이터베이스**: SQLite 데이터베이스 설정 및 `chinook` 샘플 데이터베이스 로드  \n",
    "- **유틸리티 함수**: 에이전트 구현을 위한 유틸리티 함수 정의  \n",
    "- **도구 정의**: 데이터베이스와 상호작용하기 위한 도구 정의  \n",
    "- **워크플로우 정의**: 에이전트의 워크플로우(그래프) 정의  \n",
    "- **그래프 시각화**: 정의된 그래프 시각화  \n",
    "- **에이전트 실행**: 에이전트 실행 및 결과 확인  \n",
    "- **평가**: 에이전트 평가 및 성능 비교  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8000125",
   "metadata": {},
   "source": [
    "## 환경 설정\n",
    "\n",
    "먼저, 필요한 패키지를 설치하고 API 키를 설정합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af24e6b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# API 키를 환경변수로 관리하기 위한 설정 파일\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# API 키 정보 로드\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b57903b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LangSmith 추적을 설정합니다. https://smith.langchain.com\n",
    "# !pip install -qU langchain-teddynote\n",
    "from langchain_teddynote import logging\n",
    "\n",
    "# 프로젝트 이름을 입력합니다.\n",
    "logging.langsmith(\"CH17-LangGraph-Use-Cases\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3743ff7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_teddynote.models import get_model_name, LLMs\n",
    "\n",
    "MODEL_NAME = get_model_name(LLMs.GPT4o)\n",
    "print(f\"사용하는 모델명: {MODEL_NAME}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c22810c3",
   "metadata": {},
   "source": [
    "## 데이터베이스 설정\n",
    "\n",
    "이 튜토리얼에서는 SQLite 데이터베이스를 생성합니다. SQLite는 설정과 사용이 간편한 경량 데이터베이스입니다. \n",
    "\n",
    "이번 튜토리얼에서는 샘플 데이터베이스인 `chinook` 데이터베이스를 로드할 예정이며, 이는 디지털 미디어 스토어를 나타내는 샘플 데이터베이스입니다. \n",
    "\n",
    "데이터베이스에 대한 자세한 정보는 [여기](https://www.sqlitetutorial.net/sqlite-sample-database/)에서 확인할 수 있습니다."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33ea4a58",
   "metadata": {},
   "source": [
    "먼저, 실습에 활용할 `chinook` 데이터베이스를 다운로드 받습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fc78225",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "url = \"https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db\"\n",
    "\n",
    "response = requests.get(url)\n",
    "\n",
    "if response.status_code == 200:\n",
    "    with open(\"Chinook.db\", \"wb\") as file:\n",
    "        file.write(response.content)\n",
    "    print(\"File downloaded and saved as Chinook.db\")\n",
    "else:\n",
    "    print(f\"Failed to download the file. Status code: {response.status_code}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "013e66ea",
   "metadata": {},
   "source": [
    "다음은 다운로드 받은 `chinook` 데이터베이스를 사용하여 `SQLDatabase` 도구를 생성하고 샘플 쿼리인 `\"SELECT * FROM Artist LIMIT 5;\"`를 실행합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "766ea943",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.utilities import SQLDatabase\n",
    "\n",
    "# SQLite 데이터베이스 파일에서 SQLDatabase 인스턴스 생성\n",
    "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
    "\n",
    "# DB dialect 출력(sqlite)\n",
    "print(db.dialect)\n",
    "\n",
    "# 데이터베이스에서 사용 가능한 테이블 이름 목록 출력\n",
    "print(db.get_usable_table_names())\n",
    "\n",
    "# SQL 쿼리 실행\n",
    "db.run(\"SELECT * FROM Artist LIMIT 5;\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf1ea840",
   "metadata": {},
   "source": [
    "## 유틸리티 함수\n",
    "\n",
    "에이전트 구현을 돕기 위해 몇 가지 유틸리티 함수를 정의합니다. \n",
    "\n",
    "특히, `ToolNode`를 **오류 처리** 와 **에이전트에 오류를 전달하는 기능** 을 포함하여 래핑합니다. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5ba2ff50",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "\n",
    "from langchain_core.messages import ToolMessage\n",
    "from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks\n",
    "from langgraph.prebuilt import ToolNode\n",
    "\n",
    "\n",
    "# 오류 처리 함수\n",
    "def handle_tool_error(state) -> dict:\n",
    "    # 오류 정보 조회\n",
    "    error = state.get(\"error\")\n",
    "    # 도구 정보 조회\n",
    "    tool_calls = state[\"messages\"][-1].tool_calls\n",
    "    # ToolMessage 로 래핑 후 반환\n",
    "    return {\n",
    "        \"messages\": [\n",
    "            ToolMessage(\n",
    "                content=f\"Here is the error: {repr(error)}\\n\\nPlease fix your mistakes.\",\n",
    "                tool_call_id=tc[\"id\"],\n",
    "            )\n",
    "            for tc in tool_calls\n",
    "        ]\n",
    "    }\n",
    "\n",
    "\n",
    "# 오류를 처리하고 에이전트에 오류를 전달하기 위한 ToolNode 생성\n",
    "def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:\n",
    "    \"\"\"\n",
    "    Create a ToolNode with a fallback to handle errors and surface them to the agent.\n",
    "    \"\"\"\n",
    "    # 오류 발생 시 대체 동작을 정의하여 ToolNode에 추가\n",
    "    return ToolNode(tools).with_fallbacks(\n",
    "        [RunnableLambda(handle_tool_error)], exception_key=\"error\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3be7a514",
   "metadata": {},
   "source": [
    "## SQL 쿼리 실행 도구 \n",
    "\n",
    "에이전트가 데이터베이스와 상호작용할 수 있도록 몇 가지 도구를 정의합니다.\n",
    "\n",
    "1. `list_tables_tool`: 데이터베이스에서 사용 가능한 테이블을 가져옵니다.\n",
    "2. `get_schema_tool`: 테이블의 DDL을 가져옵니다.\n",
    "3. `db_query_tool`: 쿼리를 실행하고 결과를 가져오거나 쿼리가 실패할 경우 오류 메시지를 반환합니다.\n",
    "\n",
    "**참고**\n",
    "\n",
    "- DDL(데이터 정의 언어, **Data Definition Language**)은 데이터베이스의 구조와 스키마를 정의하거나 수정하는 데 사용되는 SQL 명령어들을 지칭합니다. 주로 테이블, 인덱스, 뷰, 스키마 등의 데이터베이스 객체를 생성, 수정, 삭제할 때 사용됩니다.\n",
    "\n",
    "주요 DDL 명령어\n",
    "\n",
    "- **`CREATE`**: 데이터베이스 객체를 생성합니다.\n",
    "  - 예: `CREATE TABLE users (id INT, name VARCHAR(100));`\n",
    "- **`ALTER`**: 기존 데이터베이스 객체를 수정합니다.\n",
    "  - 예: `ALTER TABLE users ADD COLUMN email VARCHAR(100);`\n",
    "- **`DROP`**: 데이터베이스 객체를 삭제합니다.\n",
    "  - 예: `DROP TABLE users;`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be594e8c",
   "metadata": {},
   "source": [
    "### 데이터베이스 쿼리 관련 도구\n",
    "\n",
    "다음은 SQL database와 상호작용하기 위한 `SQLDatabaseToolkit` 도구 목록입니다.\n",
    "\n",
    "**QuerySQLDataBaseTool**\n",
    "\n",
    "- **기능**: SQL query 실행 및 결과 반환\n",
    "- **Input**: 정확한 SQL query\n",
    "- **Output**: Database 결과 또는 error message\n",
    "- **Error 처리**:\n",
    "  - Query 오류 발생 시 재작성 및 재시도\n",
    "  - `Unknown column` 오류 시 `sql_db_schema`로 정확한 table fields 확인\n",
    "\n",
    "**InfoSQLDatabaseTool**\n",
    "\n",
    "- **기능**: Table schema 및 sample data 조회\n",
    "- **Input**: 콤마로 구분된 table 목록\n",
    "- **사용 예시**: `table1, table2, table3`\n",
    "- **주의사항**: `sql_db_list_tables`로 table 존재 여부 사전 확인 필요\n",
    "\n",
    "**ListSQLDatabaseTool**\n",
    "\n",
    "- **기능**: Database 내 table 목록 조회\n",
    "\n",
    "**QuerySQLCheckerTool**\n",
    "\n",
    "- **기능**: Query 실행 전 유효성 검사\n",
    "- **검사 항목**:\n",
    "  - NULL 값과 NOT IN 사용\n",
    "  - UNION vs UNION ALL 적절성\n",
    "  - BETWEEN 범위 설정\n",
    "  - Data type 일치 여부\n",
    "  - Identifier 인용 적절성\n",
    "  - Function argument 수\n",
    "  - Data type casting\n",
    "  - Join column 정확성\n",
    "- **특징**: GPT-4 model 기반 검증 수행"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d301860",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# SQLDatabaseToolkit 생성\n",
    "toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model=MODEL_NAME))\n",
    "\n",
    "# SQLDatabaseToolkit에서 사용 가능한 도구 목록\n",
    "tools = toolkit.get_tools()\n",
    "tools"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c28a0eeb",
   "metadata": {},
   "source": [
    "아래는 `list_tables_tool` 과 `get_schema_tool` 에 대한 실행 예시입니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19bb31d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 데이터베이스에서 사용 가능한 테이블을 나열하는 도구 선택\n",
    "list_tables_tool = next(tool for tool in tools if tool.name == \"sql_db_list_tables\")\n",
    "\n",
    "# 특정 테이블의 DDL을 가져오는 도구 선택\n",
    "get_schema_tool = next(tool for tool in tools if tool.name == \"sql_db_schema\")\n",
    "\n",
    "# 데이터베이스의 모든 테이블 목록 출력\n",
    "print(list_tables_tool.invoke(\"\"))\n",
    "\n",
    "# Artist 테이블의 DDL 정보 출력\n",
    "print(get_schema_tool.invoke(\"Artist\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30f7d7e2",
   "metadata": {},
   "source": [
    "다음은 `db_query_tool` 을 정의합니다. \n",
    "\n",
    "`db_query_tool`의 경우, 데이터베이스에 대해 쿼리를 실행하고 결과를 반환합니다.\n",
    "\n",
    "만약, error 가 발생하면 오류 메시지를 반환합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cacbfa64",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools import tool\n",
    "\n",
    "\n",
    "# Query 실행 도구\n",
    "@tool\n",
    "def db_query_tool(query: str) -> str:\n",
    "    \"\"\"\n",
    "    Run SQL queries against a database and return results\n",
    "    Returns an error message if the query is incorrect\n",
    "    If an error is returned, rewrite the query, check, and retry\n",
    "    \"\"\"\n",
    "    # 쿼리 실행\n",
    "    result = db.run_no_throw(query)\n",
    "\n",
    "    # 오류: 결과가 없으면 오류 메시지 반환\n",
    "    if not result:\n",
    "        return \"Error: Query failed. Please rewrite your query and try again.\"\n",
    "    # 정상: 쿼리 실행 결과 반환\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da6c1cab",
   "metadata": {},
   "source": [
    "정상 실행된 경우"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7e6af4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Artist 테이블에서 상위 10개 행 선택 및 실행 결과 출력\n",
    "print(db_query_tool.invoke(\"SELECT * FROM Artist LIMIT 10;\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3cd6fd1",
   "metadata": {},
   "source": [
    "오류가 발생한 경우 예시"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4769c8af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Artist 테이블에서 상위 10개 행 선택 및 실행 결과 출력\n",
    "print(db_query_tool.invoke(\"SELECT * FROM Artist LIMITS 10;\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cdd597b",
   "metadata": {},
   "source": [
    "### SQL 쿼리 점검(SQL Query Checker)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46308526",
   "metadata": {},
   "source": [
    "다음은, SQL 쿼리에서 일반적인 실수를 점검하기 위해 LLM을 활용할 예정입니다. \n",
    "\n",
    "이는 엄밀히 말하면 도구는 아니지만, 이후 워크플로우에 노드로 추가될 것입니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f608b2aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# SQL 쿼리의 일반적인 실수를 점검하기 위한 시스템 메시지 정의\n",
    "query_check_system = \"\"\"You are a SQL expert with a strong attention to detail.\n",
    "Double check the SQLite query for common mistakes, including:\n",
    "- Using NOT IN with NULL values\n",
    "- Using UNION when UNION ALL should have been used\n",
    "- Using BETWEEN for exclusive ranges\n",
    "- Data type mismatch in predicates\n",
    "- Properly quoting identifiers\n",
    "- Using the correct number of arguments for functions\n",
    "- Casting to the correct data type\n",
    "- Using the proper columns for joins\n",
    "\n",
    "If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n",
    "\n",
    "You will call the appropriate tool to execute the query after running this check.\"\"\"\n",
    "\n",
    "# 프롬프트 생성\n",
    "query_check_prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"system\", query_check_system), (\"placeholder\", \"{messages}\")]\n",
    ")\n",
    "\n",
    "# Query Checker 체인 생성\n",
    "query_check = query_check_prompt | ChatOpenAI(\n",
    "    model=MODEL_NAME, temperature=0\n",
    ").bind_tools([db_query_tool], tool_choice=\"db_query_tool\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cd06a6a",
   "metadata": {},
   "source": [
    "잘못된 쿼리를 날려 호출하여 결과가 잘 수정되었는지 확인합니다.\n",
    "\n",
    "**참고**\n",
    "\n",
    "- `LIMIT` 대신 `LIMITS` 을 사용하여 쿼리를 날렸습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2355b32f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 사용자 메시지를 사용하여 쿼리 점검 노드 실행\n",
    "response = query_check.invoke(\n",
    "    {\"messages\": [(\"user\", \"SELECT * FROM Artist LIMITS 10;\")]}\n",
    ")\n",
    "print(response.tool_calls[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "674c8ce9",
   "metadata": {},
   "source": [
    "결과는 잘 수정되었습니다."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db713316",
   "metadata": {},
   "source": [
    "## 그래프 정의\n",
    "\n",
    "에이전트의 워크플로우를 정의합니다. \n",
    "\n",
    "에이전트는 먼저 `list_tables_tool`을 강제로 호출하여 데이터베이스에서 사용 가능한 테이블을 가져온 후, 튜토리얼 초반에 언급된 단계를 따릅니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "32474a35",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Annotated, Literal\n",
    "\n",
    "from langchain_core.messages import AIMessage\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "from pydantic import BaseModel, Field\n",
    "from typing_extensions import TypedDict\n",
    "\n",
    "from langgraph.graph import END, StateGraph, START\n",
    "from langgraph.graph.message import AnyMessage, add_messages\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "\n",
    "\n",
    "# 에이전트의 상태 정의\n",
    "class State(TypedDict):\n",
    "    messages: Annotated[list[AnyMessage], add_messages]\n",
    "\n",
    "\n",
    "# 새로운 그래프 정의\n",
    "workflow = StateGraph(State)\n",
    "\n",
    "\n",
    "# 첫 번째 도구 호출을 위한 노드 추가\n",
    "def first_tool_call(state: State) -> dict[str, list[AIMessage]]:\n",
    "    return {\n",
    "        \"messages\": [\n",
    "            AIMessage(\n",
    "                content=\"\",\n",
    "                tool_calls=[\n",
    "                    {\n",
    "                        \"name\": \"sql_db_list_tables\",\n",
    "                        \"args\": {},\n",
    "                        \"id\": \"initial_tool_call_abc123\",\n",
    "                    }\n",
    "                ],\n",
    "            )\n",
    "        ]\n",
    "    }\n",
    "\n",
    "\n",
    "# 쿼리의 정확성을 모델로 점검하기 위한 함수 정의\n",
    "def model_check_query(state: State) -> dict[str, list[AIMessage]]:\n",
    "    \"\"\"\n",
    "    Use this tool to check that your query is correct before you run it\n",
    "    \"\"\"\n",
    "    return {\"messages\": [query_check.invoke({\"messages\": [state[\"messages\"][-1]]})]}\n",
    "\n",
    "\n",
    "# 첫 번째 도구 호출 노드 추가\n",
    "workflow.add_node(\"first_tool_call\", first_tool_call)\n",
    "\n",
    "# 첫 번째 두 도구를 위한 노드 추가\n",
    "workflow.add_node(\n",
    "    \"list_tables_tool\", create_tool_node_with_fallback([list_tables_tool])\n",
    ")\n",
    "workflow.add_node(\"get_schema_tool\", create_tool_node_with_fallback([get_schema_tool]))\n",
    "\n",
    "# 질문과 사용 가능한 테이블을 기반으로 관련 테이블을 선택하는 모델 노드 추가\n",
    "model_get_schema = ChatOpenAI(model=MODEL_NAME, temperature=0).bind_tools(\n",
    "    [get_schema_tool]\n",
    ")\n",
    "workflow.add_node(\n",
    "    \"model_get_schema\",\n",
    "    lambda state: {\n",
    "        \"messages\": [model_get_schema.invoke(state[\"messages\"])],\n",
    "    },\n",
    ")\n",
    "\n",
    "\n",
    "# 최종 상태를 나타내는 도구 설명\n",
    "class SubmitFinalAnswer(BaseModel):\n",
    "    \"\"\"쿼리 결과를 기반으로 사용자에게 최종 답변 제출\"\"\"\n",
    "\n",
    "    final_answer: str = Field(..., description=\"The final answer to the user\")\n",
    "\n",
    "\n",
    "# 질문과 스키마를 기반으로 쿼리를 생성하기 위한 모델 노드 추가\n",
    "QUERY_GEN_INSTRUCTION = \"\"\"You are a SQL expert with a strong attention to detail.\n",
    "\n",
    "You can define SQL queries, analyze queries results and interpretate query results to response an answer.\n",
    "\n",
    "Read the messages bellow and identify the user question, table schemas, query statement and query result, or error if they exist.\n",
    "\n",
    "1. If there's not any query result that make sense to answer the question, create a syntactically correct SQLite query to answer the user question. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
    "\n",
    "2. If you create a query, response ONLY the query statement. For example, \"SELECT id, name FROM pets;\"\n",
    "\n",
    "3. If a query was already executed, but there was an error. Response with the same error message you found. For example: \"Error: Pets table doesn't exist\"\n",
    "\n",
    "4. If a query was already executed successfully interpretate the response and answer the question following this pattern: Answer: <<question answer>>. For example: \"Answer: There three cats registered as adopted\"\n",
    "\"\"\"\n",
    "\n",
    "query_gen_prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"system\", QUERY_GEN_INSTRUCTION), (\"placeholder\", \"{messages}\")]\n",
    ")\n",
    "query_gen = query_gen_prompt | ChatOpenAI(model=MODEL_NAME, temperature=0).bind_tools(\n",
    "    [SubmitFinalAnswer, model_check_query]\n",
    ")\n",
    "\n",
    "\n",
    "# 조건부 에지 정의\n",
    "def should_continue(state: State) -> Literal[END, \"correct_query\", \"query_gen\"]:\n",
    "    messages = state[\"messages\"]\n",
    "\n",
    "    last_message = messages[-1]\n",
    "    if last_message.content.startswith(\"Answer:\"):\n",
    "        return END\n",
    "    if last_message.content.startswith(\"Error:\"):\n",
    "        return \"query_gen\"\n",
    "    else:\n",
    "        return \"correct_query\"\n",
    "\n",
    "\n",
    "# 쿼리 생성 노드 정의\n",
    "def query_gen_node(state: State):\n",
    "    message = query_gen.invoke(state)\n",
    "\n",
    "    # LLM이 잘못된 도구를 호출할 경우 오류 메시지를 반환\n",
    "    tool_messages = []\n",
    "    message.pretty_print()\n",
    "    if message.tool_calls:\n",
    "        for tc in message.tool_calls:\n",
    "            if tc[\"name\"] != \"SubmitFinalAnswer\":\n",
    "                tool_messages.append(\n",
    "                    ToolMessage(\n",
    "                        content=f\"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.\",\n",
    "                        tool_call_id=tc[\"id\"],\n",
    "                    )\n",
    "                )\n",
    "    else:\n",
    "        tool_messages = []\n",
    "    return {\"messages\": [message] + tool_messages}\n",
    "\n",
    "\n",
    "# 쿼리 생성 노드 추가\n",
    "workflow.add_node(\"query_gen\", query_gen_node)\n",
    "\n",
    "# 쿼리를 실행하기 전에 모델로 점검하는 노드 추가\n",
    "workflow.add_node(\"correct_query\", model_check_query)\n",
    "\n",
    "# 쿼리를 실행하기 위한 노드 추가\n",
    "workflow.add_node(\"execute_query\", create_tool_node_with_fallback([db_query_tool]))\n",
    "\n",
    "# 노드 간의 엣지 지정\n",
    "workflow.add_edge(START, \"first_tool_call\")\n",
    "workflow.add_edge(\"first_tool_call\", \"list_tables_tool\")\n",
    "workflow.add_edge(\"list_tables_tool\", \"model_get_schema\")\n",
    "workflow.add_edge(\"model_get_schema\", \"get_schema_tool\")\n",
    "workflow.add_edge(\"get_schema_tool\", \"query_gen\")\n",
    "workflow.add_conditional_edges(\n",
    "    \"query_gen\",\n",
    "    should_continue,\n",
    ")\n",
    "workflow.add_edge(\"correct_query\", \"execute_query\")\n",
    "workflow.add_edge(\"execute_query\", \"query_gen\")\n",
    "\n",
    "# 실행 가능한 워크플로우로 컴파일\n",
    "app = workflow.compile(checkpointer=MemorySaver())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7d7f2d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_teddynote.graphs import visualize_graph\n",
    "\n",
    "visualize_graph(app, xray=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31e48ae3",
   "metadata": {},
   "source": [
    "## 그래프 실행\n",
    "\n",
    "에이전트를 실행하여 SQL 데이터베이스와 상호작용하는 전체 프로세스를 진행합니다.\n",
    "\n",
    "에이전트는 사용자의 질문에 따라 데이터베이스에서 정보를 검색하고, 쿼리를 생성 및 실행하여 결과를 반환합니다. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2ca79a8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.runnables import RunnableConfig\n",
    "from langchain_teddynote.messages import random_uuid, invoke_graph, stream_graph\n",
    "from langchain_core.messages import HumanMessage\n",
    "from langgraph.errors import GraphRecursionError\n",
    "\n",
    "\n",
    "def run_graph(\n",
    "    message: str, recursive_limit: int = 30, node_names=[], stream: bool = False\n",
    "):\n",
    "    # config 설정(재귀 최대 횟수, thread_id)\n",
    "    config = RunnableConfig(\n",
    "        recursion_limit=recursive_limit, configurable={\"thread_id\": random_uuid()}\n",
    "    )\n",
    "\n",
    "    # 질문 입력\n",
    "    inputs = {\n",
    "        \"messages\": [HumanMessage(content=message)],\n",
    "    }\n",
    "\n",
    "    try:\n",
    "        if stream:\n",
    "            # 그래프 실행\n",
    "            stream_graph(app, inputs, config, node_names=node_names)\n",
    "        else:\n",
    "            invoke_graph(app, inputs, config, node_names=node_names)\n",
    "        output = app.get_state(config).values\n",
    "        return output\n",
    "    except GraphRecursionError as recursion_error:\n",
    "        print(f\"GraphRecursionError: {recursion_error}\")\n",
    "        output = app.get_state(config).values\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b51d2282",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = run_graph(\n",
    "    \"Andrew Adam 직원의 인적정보를 모두 조회해줘\",\n",
    "    stream=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf2d13fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = run_graph(\n",
    "    \"2009년도에 어느 국가의 고객이 가장 많이 지출했을까요? 그리고 얼마를 지출했을까요? 한글로 답변하세요.\",\n",
    "    stream=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe744388",
   "metadata": {},
   "source": [
    "## LangSmith Evaluator 를 활용한 SQL Agent 평가\n",
    "\n",
    "이제 생성한 Agent 의 SQL 쿼리 응답을 평가합니다. 쿼리 응답을 평가하기 위한 평가용 데이터셋을 생성합니다.\n",
    "\n",
    "다음으로는 평가자를 정의하고 평가를 진행합니다.\n",
    "\n",
    "이때 활용하는 평가자는 LLM-as-judge 이며, 사용하는 프롬프트는 기본 hub 에서 제공하는 프롬프트를 활용합니다.\n",
    "\n",
    "다만, 보다 정확한 평가를 위해서 각자 프롬프트를 튜닝하여 사용하는 것을 권장합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "63565cc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langsmith import Client\n",
    "\n",
    "# 클라이언트 초기화\n",
    "client = Client()\n",
    "\n",
    "# 데이터셋 생성 및 업로드\n",
    "examples = [\n",
    "    (\n",
    "        \"Which country's customers spent the most? And how much did they spend?\",\n",
    "        \"The country whose customers spent the most is the USA, with a total spending of 523.06.\",\n",
    "    ),\n",
    "    (\n",
    "        \"What was the most purchased track of 2013?\",\n",
    "        \"The most purchased track of 2013 was Hot Girl.\",\n",
    "    ),\n",
    "    (\n",
    "        \"How many albums does the artist Led Zeppelin have?\",\n",
    "        \"Led Zeppelin has 14 albums\",\n",
    "    ),\n",
    "    (\n",
    "        \"What is the total price for the album “Big Ones”?\",\n",
    "        \"The total price for the album 'Big Ones' is 14.85\",\n",
    "    ),\n",
    "    (\n",
    "        \"Which sales agent made the most in sales in 2009?\",\n",
    "        \"Steve Johnson made the most sales in 2009\",\n",
    "    ),\n",
    "]\n",
    "\n",
    "dataset_name = \"SQL Agent Response\"\n",
    "\n",
    "if not client.has_dataset(dataset_name=dataset_name):\n",
    "    dataset = client.create_dataset(dataset_name=dataset_name)\n",
    "    inputs, outputs = zip(\n",
    "        *[({\"input\": text}, {\"output\": label}) for text, label in examples]\n",
    "    )\n",
    "    client.create_examples(inputs=inputs, outputs=outputs, dataset_id=dataset.id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e01f99b",
   "metadata": {},
   "source": [
    "다음으로는 우리가 만든 에이전트의 SQL 쿼리 응답을 예측하기 위한 함수를 정의합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a446fb2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 에이전트의 SQL 쿼리 응답을 예측하기 위한 함수 정의\n",
    "def predict_sql_agent_answer(example: dict):\n",
    "    \"\"\"Use this for answer evaluation\"\"\"\n",
    "    config = RunnableConfig(configurable={\"thread_id\": random_uuid()})\n",
    "\n",
    "    inputs = {\n",
    "        \"messages\": [HumanMessage(content=example[\"input\"])],\n",
    "    }\n",
    "    # 그래프를 실행하여 메시지 결과 조회\n",
    "    messages = app.invoke(inputs, config)\n",
    "    answer = messages[\"messages\"][-1].content\n",
    "    # 결과 반환\n",
    "    return {\"response\": answer}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfecdfb1",
   "metadata": {},
   "source": [
    "SQL 쿼리 응답을 평가하기 위한 프롬프트와 평가자(LLM-as-judge) 를 정의합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f444566e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain import hub\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# Grade prompt\n",
    "grade_prompt_answer_accuracy = hub.pull(\"langchain-ai/rag-answer-vs-reference\")\n",
    "\n",
    "\n",
    "# 답변 평가자 LLM-as-judge 정의\n",
    "def answer_evaluator(run, example) -> dict:\n",
    "    # input: 질문\n",
    "    input_question = example.inputs[\"input\"]\n",
    "    # output: 참조 답변\n",
    "    reference = example.outputs[\"output\"]\n",
    "    # 예측 답변\n",
    "    prediction = run.outputs[\"response\"]\n",
    "\n",
    "    # LLM 평가자 초기화\n",
    "    llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
    "    answer_grader = grade_prompt_answer_accuracy | llm\n",
    "\n",
    "    # 평가자 실행\n",
    "    score = answer_grader.invoke(\n",
    "        {\n",
    "            \"question\": input_question,\n",
    "            \"correct_answer\": reference,\n",
    "            \"student_answer\": prediction,\n",
    "        }\n",
    "    )\n",
    "    score = score[\"Score\"]\n",
    "\n",
    "    # 점수 반환\n",
    "    return {\"key\": \"answer_v_reference_score\", \"score\": score}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72159566",
   "metadata": {},
   "source": [
    "이제, 평가를 수행하고 결과를 확인합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005efae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langsmith.evaluation import evaluate\n",
    "\n",
    "# 평가용 데이터셋 이름\n",
    "dataset_name = \"SQL Agent Response\"\n",
    "\n",
    "try:\n",
    "    # 평가 진행\n",
    "    experiment_results = evaluate(\n",
    "        predict_sql_agent_answer,  # 평가시 활용할 예측 함수\n",
    "        data=dataset_name,  # 평가용 데이터셋 이름\n",
    "        evaluators=[answer_evaluator],  # 평가자 목록\n",
    "        num_repetitions=3,  # 실험 반복 횟수 설정\n",
    "        experiment_prefix=\"sql-agent-eval\",\n",
    "        metadata={\"version\": \"chinook db, sql-agent-eval: gpt-4o\"},  # 실험 메타데이터\n",
    "    )\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "239741fd",
   "metadata": {},
   "source": [
    "평가 결과는 생성된 URL 에서 각자 확인할 수 있습니다."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16c9d2ff",
   "metadata": {},
   "source": [
    "![](./assets/langgraph-sql-agent-evaluation.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain-kr-lwwSZlnu-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
